# Copyright 2017 Rice University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import sys
import json
import math
import random
import numpy as np
from itertools import chain

import bayou.models.low_level_evidences.evidence
from bayou.models.low_level_evidences.utils import gather_calls

# Description text handed to argparse (as `description`) for the command-line help output.
HELP = """Use this script to extract evidences from a raw data file with sequences generated by driver.
You can also filter programs based on number and length of sequences, and control the samples from each program."""


def _unique_evidence(evidence_cls, calls):
    """Return the distinct items produced by ``evidence_cls.from_call`` over all ``calls``, as a list.

    The result goes through a ``set``, so its order is whatever the set yields.
    """
    return list(set(chain.from_iterable(evidence_cls.from_call(call) for call in calls)))


def _choose_evidences(evidences, clargs):
    """Randomly pick the (evidence, kind) pairs that are visible in one sample.

    :param evidences: list of (evidence, kind) tuples; shuffled IN PLACE in distribution mode
    :param clargs: parsed arguments; reads ``observability`` and ``distribution``
    :return: list of the chosen (evidence, kind) tuples
    :raises ValueError: if neither ``observability`` nor ``distribution`` is set
    """
    if clargs.observability is not None:
        # a non-positive observability means "draw a fresh percentage for this sample"
        observability = clargs.observability if clargs.observability > 0 else random.randint(1, 100)
        return random.sample(evidences, math.ceil(len(evidences) * observability / 100))
    if clargs.distribution is not None:
        random.shuffle(evidences)
        # index i of the distribution is the probability of keeping i+1 evidences
        num = np.random.choice(range(len(clargs.distribution)), p=clargs.distribution)
        return evidences[:num + 1]
    raise ValueError('Invalid option for sampling')


def extract_evidence(clargs):
    """Extract evidences for the programs in the input file and write them to the output file.

    Loads JSON from ``clargs.input_file[0]`` and walks ``js['programs']``. A program is skipped
    when it has more than ``clargs.max_seqs`` sequences or when any of its sequences has more than
    ``clargs.max_seq_length`` calls. Every kept program gets 'apicalls', 'types' and 'keywords'
    lists derived from the calls gathered from its 'ast'.

    * ``clargs.num_samples == 0``: the program is emitted once with all of its evidences.
    * ``clargs.num_samples > 0``: that many shallow copies are emitted, each with a random subset.
    * ``clargs.num_samples < 0``: as above, with ceil(#evidences / -num_samples) copies.

    The result is written to ``clargs.output_file[0]`` as ``{'programs': [...]}``.

    :param clargs: parsed command-line arguments (see the argument parser of this script)
    :raises ValueError: if sampling is requested but neither ``clargs.observability`` nor
        ``clargs.distribution`` is set
    """
    print('Loading data file...', end='')
    with open(clargs.input_file[0]) as f:
        js = json.load(f)
    print('done')

    done = 0
    programs = []
    for program in js['programs']:
        sequences = program['sequences']
        if len(sequences) > clargs.max_seqs or \
                any(len(sequence['calls']) > clargs.max_seq_length for sequence in sequences):
            continue

        calls = gather_calls(program['ast'])
        evidence_module = bayou.models.low_level_evidences.evidence
        apicalls = _unique_evidence(evidence_module.APICalls, calls)
        types = _unique_evidence(evidence_module.Types, calls)
        keywords = _unique_evidence(evidence_module.Keywords, calls)

        if clargs.num_samples == 0:
            program['apicalls'] = apicalls
            program['types'] = types
            program['keywords'] = keywords
            programs.append(program)
        else:
            # put all evidences in the same bag (to avoid bias during sampling)
            evidences = [(e, 'apicalls') for e in apicalls] + [(e, 'types') for e in types] + \
                        [(e, 'keywords') for e in keywords]
            # negative num_samples is adaptive: -k means (number of evidences / k) samples
            num_samples = clargs.num_samples if clargs.num_samples > 0 \
                else math.ceil(len(evidences) / -clargs.num_samples)

            for _ in range(num_samples):
                # shallow copy: samples share the program's other fields, only the evidence lists differ
                sample = dict(program)
                sample['apicalls'] = []
                sample['types'] = []
                sample['keywords'] = []

                for choice, evidence in _choose_evidences(evidences, clargs):
                    sample[evidence].append(choice)
                programs.append(sample)

        done += 1
        print('Extracted evidence for {} programs'.format(done), end='\r')

    print('\nWriting to {}...'.format(clargs.output_file[0]), end='')
    with open(clargs.output_file[0], 'w') as f:
        json.dump({'programs': programs}, fp=f, indent=2)
    print('done')


if __name__ == '__main__':
    # Command-line entry point: parse the arguments, validate the sampling options and run the extraction.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=HELP)
    parser.add_argument('input_file', type=str, nargs=1,
                        help='input data file')
    parser.add_argument('output_file', type=str, nargs=1,
                        help='output data file')
    parser.add_argument('--python_recursion_limit', type=int, default=10000,
                        help='set recursion limit for the Python interpreter')
    parser.add_argument('--max_seqs', type=int, default=9999,
                        help='maximum number of sequences in a program')
    parser.add_argument('--max_seq_length', type=int, default=9999,
                        help='maximum length of each sequence in a program')
    parser.add_argument('--num_samples', type=int, default=0,
                        help='number of samples per program (< 0 = adaptive, e.g., -k = number of evidences/k)')
    parser.add_argument('--observability', type=int, default=None,
                        help='percentage of observable evidence (e.g., 100, 75, 50, etc.. 0 = random)')
    parser.add_argument('--distribution', nargs='+', type=float, default=None,
                        help='distribution over number of evidences in each sample (e.g., 0.3 0.5 0.2). Must sum to 1.')
    clargs = parser.parse_args()
    sys.setrecursionlimit(clargs.python_recursion_limit)
    # extract_evidence samples for ANY non-zero --num_samples (negative values are the adaptive
    # mode), so the sampling options have to be validated for negative counts as well.
    if clargs.num_samples != 0:
        # exactly one of the two sampling modes must be chosen
        if (clargs.observability is None) == (clargs.distribution is None):
            parser.error('Provide exactly one of --observability or --distribution')
        # more than 100% would ask random.sample for more evidences than exist
        if clargs.observability is not None and clargs.observability > 100:
            parser.error('--observability is a percentage and must not exceed 100')
    extract_evidence(clargs)
